# coding=utf-8

from sqlalchemy import create_engine, Column, Integer, String, Boolean, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy import and_

from log import debug

# Connection URI used by this module, in SQLAlchemy URL form, e.g.
#   mysql+mysqldb://USER:PASSWORD@HOST:PORT/DB_NAME
# It starts out empty and is set at runtime by notify(); keep real
# credentials out of source control.
DATABASE_URI = ""
# Declarative base for this module's mapped models (its metadata is used to
# create tables when an executor is constructed).
BaseModel = declarative_base()

# Convert a SQLAlchemy mapped object into a plain dict.
def to_dict(self):
    """Return this mapped instance's column values as a plain dict.

    Keys are the table's column names, in table order. The column named
    "Status" is left out. A column with no matching attribute on the
    instance maps to None.
    """
    result = {}
    for column in self.__table__.columns:
        if column.name == "Status":
            continue
        result[column.name] = getattr(self, column.name, None)
    return result

# Attach to_dict to the declarative base so every model class derived from
# BaseModel gets a to_dict() method (used by CommonDBExecutor.query).
BaseModel.to_dict = to_dict

class CommonDBExecutor(object):
    """CRUD helper bound to a single SQLAlchemy mapped class.

    Each instance creates its own engine and session, makes sure every table
    registered on ``BaseModel`` exists, and exposes insert/query/update/delete
    operations filtered by keyword arguments.

    Filter keywords are attribute names on ``table``. Prefixing a name with
    ``not_`` turns the equality test into an inequality test, e.g.
    ``executor.query(status=1, not_owner="bob")``. Multiple filters are ANDed.

    NOTE(review): every instance builds a separate engine (and so a separate
    connection pool); reuse executors rather than creating one per call.
    """

    def __init__(self, db_url=None, table=None):
        """Create the engine/session for ``db_url`` and bind to ``table``.

        :param db_url: SQLAlchemy database URL. When omitted (``None``) the
            module-level ``DATABASE_URI`` is read at call time, so a value
            stored by ``notify()`` is honoured.
        :param table: mapped class that every operation acts on (required).
        :raises Exception: if ``table`` is not given.
        """
        # Validate first so a bad call does not build an engine and run
        # create_all() for nothing.
        if not table:
            raise Exception("Failed to get table for executor.")
        if db_url is None:
            # Looked up here rather than as the signature default, which would
            # be frozen to the value DATABASE_URI had at class-definition time.
            db_url = DATABASE_URI
        self.tb = table
        self.db = create_engine(db_url, pool_size=100, max_overflow=200,
                                pool_recycle=3600, encoding='utf8')
        BaseModel.metadata.create_all(bind=self.db)
        DBSession = sessionmaker(bind=self.db)
        self.session = DBSession()

    def __del__(self):
        """Close this executor's session when the object is collected."""
        # Session.close() releases only this executor's session; the
        # Session.close_all() classmethod would close every session in the
        # process, including ones other executors are still using.
        # getattr() covers an __init__ that failed before the session existed.
        session = getattr(self, "session", None)
        if session is not None:
            session.close()

    def _commit(self):
        """Commit the session, rolling back on failure before re-raising."""
        try:
            self.session.commit()
        except Exception:
            # After a failed commit SQLAlchemy refuses further work on the
            # session until rollback() is called; do it here so the executor
            # stays usable, then propagate the original error.
            self.session.rollback()
            raise

    def insert(self, **kwargs):
        """Build one ``table`` row from ``kwargs``, add it and commit."""
        self.session.add(self.tb(**kwargs))
        self._commit()

    # Column names may be prefixed with "not_" to turn the filter into an
    # inequality test against that column's value.
    def _filter_kwargs_map(self, filter_table):
        """Translate keyword filters into a list of SQLAlchemy criteria.

        A key ``"col"`` becomes ``table.col == value``; a key ``"not_col"``
        becomes ``table.col != value``. The bare key ``"not_"`` (nothing after
        the prefix) is treated as a literal attribute name.

        :raises AttributeError: if a key does not name an attribute of ``table``.
        """
        prefix = "not_"
        criteria = []
        for key, val in filter_table.items():
            if key.startswith(prefix) and len(key) > len(prefix):
                criteria.append(getattr(self.tb, key[len(prefix):]) != val)
            else:
                criteria.append(getattr(self.tb, key) == val)
        return criteria

    def _build_query(self, filters):
        """Return a query over ``table`` restricted by the keyword ``filters``."""
        q = self.session.query(self.tb)
        criteria = self._filter_kwargs_map(filters)
        if criteria:
            # filter() joins several positional criteria with AND.
            q = q.filter(*criteria)
        return q

    def query(self, **kwargs):
        """Return all rows matching ``kwargs`` as a list of dicts (``to_dict``)."""
        return [row.to_dict() for row in self._build_query(kwargs).all()]

    def update(self, update_dict=None, **kwargs):
        """Bulk-update rows matching ``kwargs`` with ``update_dict`` and commit.

        :param update_dict: mapping of column -> new value (``None`` means ``{}``).
        :returns: the row count reported by ``Query.update()``.
        """
        if update_dict is None:
            update_dict = {}
        count = self._build_query(kwargs).update(update_dict)
        self._commit()
        return count

    def delete(self, **kwargs):
        """Delete rows matching ``kwargs`` (all rows if none) and commit.

        :returns: the row count reported by ``Query.delete()``.
        """
        count = self._build_query(kwargs).delete()
        self._commit()
        return count

def notify(user, passwd, host, port, name, engine="mysql"):
    """Store a new MySQL connection URI in the module-level ``DATABASE_URI``.

    The URI has the form
    ``mysql+mysqldb://USER:PASSWD@HOST:PORT/NAME?charset=utf8mb4`` and can be
    read back with ``get_db_uri()``. This only updates the stored string; it
    does not rebuild any engine or session that already exists.

    :param user: database user name (raw, not URL-encoded).
    :param passwd: database password (raw, not URL-encoded).
    :param host: database host.
    :param port: database port (int or str).
    :param name: database (schema) name.
    :param engine: currently unused; the URI always uses ``mysql+mysqldb``.
    """
    global DATABASE_URI
    try:
        from urllib.parse import quote
    except ImportError:  # Python 2
        from urllib import quote

    # Percent-encode the credentials: characters such as "@", ":" or "/" in a
    # password would otherwise be parsed as URL delimiters. quote() (not
    # quote_plus) is used so a space becomes %20, which URL parsers decode.
    n_db_url = 'mysql+mysqldb://%s:%s@%s:%s/%s?charset=utf8mb4' % (
        quote(str(user), safe=""), quote(str(passwd), safe=""), host, port, name)

    DATABASE_URI = n_db_url

def get_db_uri():
    """Return the module-level database URI (an empty string until notify() sets it)."""
    # Reading a module global needs no `global` declaration.
    return DATABASE_URI
